""" Robustness check for myopic wells: simulate the sorting gains for several
exit probabilities and write them into a LaTeX table.

Inputs:
    - The simulated method of moments robustness checks must first be run on the cluster.

"""

#%% IMPORTS -----------------------------------------------------------------------------
import copy
import os
import sys

import numpy as np
import pandas as pd
from scipy.stats import norm as norm

sys.path.append('./')

from src.models_new import simulation, surplus
from src.run_scripts import utils

# Optional command-line arguments:
#   argv[1] -> directory of the Overleaf project the table is written into
#   argv[2] -> repository root (not referenced elsewhere in this script)
# Each argument falls back to its own default, so passing only the first one
# no longer discards it.
overleaf_path = (
    sys.argv[1] if len(sys.argv) > 1
    else './../Apps/Overleaf/bbm/draft/reports/revision/'
)
root_path = sys.argv[2] if len(sys.argv) > 2 else './'

# Multiplier applied to the summed welfare difference stored in `tex_input`.
conversion_factor = 30
# Forwarded to utils.read_in_data as `use_myopic`.
do_myopic_optimization = False
# Flag kept for the welfare-graph option (not read elsewhere in this script).
welfare_graph = True
# Upper truncation point of the normal distribution used for params['denom'].
mri_max = 2.15

# Exit probabilities to run; 1.0 loads the baseline parameter file.
p_exit_runs = [0.75, 0.95, 1.0]
# Per-rig operating cost rate used to build `total_opex`.
opex_daily = 0.032

# %% READ IN THE INPUTS -----------------------------------------------------------------
def _exit_key(p_exit):
    """Return the table key for an exit probability, e.g. 0.95 -> 'sorting_95'.

    ``round`` (not ``int``) is used so that float products landing just below
    an integer (0.57 * 100 == 56.99999999999999) map to the intended percent.
    """
    return f'sorting_{round(p_exit * 100)}'


def _total_entry_cost(shares_by_state, state_data, params):
    """Approximate total entry cost implied by a set of simulated shares.

    Columns 0-2 of ``shares_by_state`` are summed into the entering share
    (one minus the no-entry share). The per-state term
    ``(d_0 + g * d_1) * entering_share`` is summed and scaled by ``c``.
    """
    df = pd.DataFrame(np.asarray(shares_by_state))
    df['no_enter'] = 1 - df[0] - df[1] - df[2]
    # NOTE(review): relies on state_data['g'] sharing df's default RangeIndex
    # for row alignment — confirm.
    df['enter_approx'] = (
        params['d_0'] + state_data['g'] * params['d_1']) * (1 - df['no_enter'])
    return df['enter_approx'].sum() * params['c']


tex_input = dict()
pct_by_exit = dict()
entry_cost_by_exit = dict()
for p_exit in p_exit_runs:
    data, delta, rho, c, weights = utils.read_in_data(
        path_moments_data='./models/smm_input/moments_empirical.csv',
        path_n_rigs="./models/first_stage/n_rigs",
        path_surplus_components="./models/surplus/surplus_components",
        path_surplus_grid='./models/surplus/surplus_grid_2_low_month.npy',
        path_df_state='./data_py/processed/states',
        path_delta='./models/smm_input/delta.csv',
        path_rho='./models/smm_input/rho.csv',
        path_df_contracts='./data_py/processed/contracts_final.csv',
        path_price_match_values='./models/price_match/price_match_values',
        path_coefs_data='./models/smm_input/coefs_data',
        path_prob_match_predict_contracts='./models/robustness/prob_match_predict_contracts',
        path_prob_match_predict='./models/robustness/prob_match_predict',
        time_period='month',
        p_exit=p_exit,
        use_myopic=do_myopic_optimization
    )

    # p_exit == 1.0 uses the baseline estimates; other exit probabilities
    # have their own parameter files.
    if p_exit == 1.0:
        path_load = './models/smm/params_smm_with_diff_new.csv'
    else:
        path_load = f'./models/smm/params_smm_with_diff_new_myopic_{p_exit}.csv'

    # read_csv(squeeze=True) was removed in pandas 2.0; squeezing the
    # single-column frame afterwards yields the same Series.
    params = pd.read_csv(path_load, index_col=[0]).squeeze('columns').to_dict()
    # NOTE(review): a_0 reuses a_0_low for all three entries while a_1 uses
    # low/mid/high — confirm this is intended.
    params['a_0'] = np.array(
        [params['a_0_low'], params['a_0_low'], params['a_0_low']])
    params['a_1'] = np.array(
        [params['a_1_low'], params['a_1_mid'], params['a_1_high']])
    # Normal probability mass on [0, mri_max].
    params['denom'] = (
            norm.cdf(mri_max, loc=params['mu_0'], scale=params['sigma_0'])
            - norm.cdf(0, loc=params['mu_0'], scale=params['sigma_0'])
    )
    params['p_4'] = 1 - params['p_3'] - params['p_2']
    # Disabled alternative: override delta with the adjusted value per p_exit.
    #params['delta'] = pd.read_csv(
     #   f'./models/robustness/delta_adjusted_{p_exit}.csv',
      #  index_col=[0]
    #)['0'].loc[0]

    #%% 3. Get updated surplus
    (
        surplus_values_by_tau_spec_non_myopic,
        well_outside_option_by_tau_spec
    ) = surplus.build_fast_surplus(
        data['match_values_by_tau_spec'],
        params,
        data['surplus_grid'],
        data['non_myopic_dict']
    )

    #%% DO SIMULATION -----------------------------------------------------------------------
    results_benchmark = simulation.do_simulation(
        data['state_data'],
        params,
        data['surplus_grid'],
        surplus_values_by_tau_spec_non_myopic,
        data['n_rigs'],
        mri_max=mri_max,
        entry_prob_by_tau_by_ym=None,
        verbose=True,
        value_zero=False
    )
    moments_by_state_benchmark = pd.DataFrame(results_benchmark['moments_by_state'])

    # Get the aggregate moments from the simulation
    moments_sim_benchmark = simulation.get_aggregated_moments(
        moments_by_state_benchmark, data['g_data'], data['n_rigs'])

    #%% DO MYOPIC COUNTERFACTUAL ------------------------------------------------------------
    # Counterfactual: zero out gamma and re-simulate, holding the benchmark's
    # entry probabilities fixed.
    params['gamma'] = 0.0
    params['gamma_negative'] = 0.0
    # params['delta'] = 0.35708

    results_myopic = simulation.do_simulation(
        data['state_data'],
        params,
        data['surplus_grid'],
        surplus_values_by_tau_spec_non_myopic,
        data['n_rigs'],
        mri_max=mri_max,
        value_zero=True,
        entry_prob_by_tau_by_ym=results_benchmark['entry_prob_by_tau_by_ym'],
        verbose=True)

    moments_by_state_myopic = pd.DataFrame(results_myopic['moments_by_state'])

    # Get the aggregate moments from the simulation
    moments_sim_myopic = simulation.get_aggregated_moments(
        moments_by_state_myopic, data['g_data'], data['n_rigs'])

    #%% COMPUTE WELFARE ---------------------------------------------------------------------
    # NOTE(review): `params` already has gamma/gamma_negative set to 0.0 here,
    # so the benchmark welfare is also evaluated with those values — confirm
    # this is intended.
    total_value_by_spec = simulation.get_welfare_from_simulation(
        results_benchmark['state_detail_by_ym'],
        params,
        mri_max,
        data['n_rigs'],
        data['state_data']
    )
    total_value_by_spec_myopic = simulation.get_welfare_from_simulation(
        results_myopic['state_detail_by_ym'], params, mri_max,
        data['n_rigs'], data['state_data'])

    #%% GET DIFFERENCES ---------------------------------------------------------------------
    total_entry_cost = _total_entry_cost(
        results_benchmark['shares_by_state'], data['state_data'], params)
    total_entry_cost_myopic = _total_entry_cost(
        results_myopic['shares_by_state'], data['state_data'], params)

    # Record entry cost (a scalar, so no defensive copy is needed).
    entry_cost_by_exit[p_exit] = total_entry_cost

    # Per-period welfare summed over the three specs.
    specs = ('low', 'mid', 'high')
    total_by_ym = sum(total_value_by_spec[spec] for spec in specs)
    total_myopic_by_ym = sum(total_value_by_spec_myopic[spec] for spec in specs)

    #%% COMBINE RESULTS -----------------------------------------------------------------
    total_opex = 120 * (
            data['n_rigs']['low'] + data['n_rigs']['mid'] + data['n_rigs']['high']) * opex_daily
    # Myopic welfare net of its entry cost and operating cost.
    net_myopic_value = total_myopic_by_ym.sum() - total_entry_cost_myopic - total_opex
    sorting_gain_by_ym = total_by_ym - total_myopic_by_ym
    # Per-period gain relative to net myopic value divided by 120 (presumably
    # the number of periods — TODO confirm); kept for interactive inspection,
    # it is not written to the table.
    diff_0 = sorting_gain_by_ym / (net_myopic_value / 120)

    key = _exit_key(p_exit)
    pct_by_exit[key] = round(100 * sorting_gain_by_ym.sum() / net_myopic_value, 1)
    tex_input[key] = round(sorting_gain_by_ym.sum() * conversion_factor, 1)

#%% ADD INTO A TABLE --------------------------------------------------------------------
# Fill the LaTeX template with the percentage gains in `pct_by_exit` via
# str.format (so its placeholders are presumably named like the keys of
# `pct_by_exit`, e.g. {sorting_95} — confirm against the template) and write
# the finished table into the Overleaf project. The template is read in full
# before the output file is opened.
with open('./src/tex/table_robust_p_exit.tex', 'r', encoding='utf-8') as f:
    tex = f.read()
output = tex.format(**pct_by_exit)

# os.path.join tolerates an overleaf_path given without a trailing slash.
table_path = os.path.join(overleaf_path, 'tables', 'table_robust_p_exit.tex')
with open(table_path, 'w', encoding='utf-8') as f:
    f.write(output)

